#  -*- coding: utf-8 -*-
#  Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#  http://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""
msmodelslim.utils.patch.torch 模块的单元测试
"""
from unittest.mock import Mock
import torch
from torch import nn
import pytest

from msmodelslim.utils.patch.torch import (
    patch_torch,
    _is_torch_nn_module_has_get_submodule,
    _is_torch_nn_module_has_set_submodule,
    _is_torch_has_get_default_device,
    _TORCH_DEFAULT_DEVICE,
)


class TestTorchPatch:

    @staticmethod
    def test_patch_keeps_get_submodule_when_existing():
        if not hasattr(nn.Module, "get_submodule"):
            pytest.skip("当前 PyTorch 无原生 get_submodule，跳过此用例")

        original_method = nn.Module.get_submodule
        patch_torch()
        assert (
            nn.Module.get_submodule is original_method
        ), "补丁不应覆盖原生 get_submodule"

    @staticmethod
    def test_patch_keeps_set_submodule_when_existing():
        if not hasattr(nn.Module, "set_submodule"):
            pytest.skip("当前 PyTorch 无原生 set_submodule，跳过此用例")

        original_method = nn.Module.set_submodule
        patch_torch()
        assert (
            nn.Module.set_submodule is original_method
        ), "补丁不应覆盖原生 set_submodule"

    @staticmethod
    def test_patch_adds_get_default_device_when_missing():
        if hasattr(torch, "get_default_device"):
            delattr(torch, "get_default_device")
        assert not _is_torch_has_get_default_device(), "初始状态应无 get_default_device"

        patch_torch()
        assert _is_torch_has_get_default_device(), "补丁应补充 get_default_device"

        assert torch.get_default_device() == torch.device("cpu"), "初始默认设备应为 CPU"
        target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        torch.set_default_device(target_device)
        assert torch.get_default_device() == target_device, "默认设备应更新成功"

    @staticmethod
    def test_patch_keeps_get_default_device_when_existing():
        if not hasattr(torch, "get_default_device"):
            pytest.skip("当前 PyTorch 无原生 get_default_device，跳过此用例")

        original_method = torch.get_default_device
        patch_torch()
        assert (
            torch.get_default_device is original_method
        ), "补丁不应覆盖原生 get_default_device"

    @staticmethod
    def test_patch_adds_get_submodule_when_missing(mock_self):
        if hasattr(nn.Module, "get_submodule"):
            delattr(nn.Module, "get_submodule")
        assert not _is_torch_nn_module_has_get_submodule(), "初始状态应无 get_submodule"

        patch_torch()
        assert _is_torch_nn_module_has_get_submodule(), "补丁应补充 get_submodule"

        assert isinstance(mock_self.test_model.get_submodule("conv"), nn.Conv2d)  # 顶层
        assert isinstance(
            mock_self.test_model.get_submodule("inner.linear"), nn.Linear
        )  # 嵌套
        assert (
            mock_self.test_model.get_submodule("inner.non_exist") is None
        )  # 不存在的子模块返回 None

    @staticmethod
    def test_patch_adds_set_submodule_when_missing(mock_self):
        if hasattr(nn.Module, "set_submodule"):
            delattr(nn.Module, "set_submodule")
        assert not _is_torch_nn_module_has_set_submodule(), "初始状态应无 set_submodule"

        patch_torch()
        assert _is_torch_nn_module_has_set_submodule(), "补丁应补充 set_submodule"

        new_linear = nn.Linear(5, 2)
        mock_self.test_model.set_submodule("inner.linear", new_linear)  # 嵌套设置
        assert mock_self.test_model.inner.linear is new_linear, "嵌套子模块应设置成功"

        new_conv = nn.Conv2d(16, 32, 3)
        mock_self.test_model.set_submodule("conv", new_conv)  # 顶层设置
        assert mock_self.test_model.conv is new_conv, "顶层子模块应设置成功"

    @pytest.fixture
    def mock_self(self):
        mock = Mock()
        """测试前准备：保存原始方法+初始化测试模型"""
        # 1. 保存 PyTorch 原生方法（用于测试后恢复）
        mock.original_get_submodule = getattr(nn.Module, "get_submodule", None)
        mock.original_set_submodule = getattr(nn.Module, "set_submodule", None)
        mock.original_get_default_device = getattr(torch, "get_default_device", None)
        mock.original_set_default_device = getattr(torch, "set_default_device", None)

        # 2. 初始化嵌套模型（模拟真实场景的子模块结构）
        class InnerModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(10, 5)
                self.relu = nn.ReLU()

        class TestModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.inner = InnerModel()
                self.conv = nn.Conv2d(3, 16, 3)

        mock.test_model = TestModel()
        yield mock

        # === teardown ===
        for name, original in [
            ("get_submodule", mock.original_get_submodule),
            ("set_submodule", mock.original_set_submodule),
        ]:
            if original is not None:
                setattr(nn.Module, name, original)
            elif hasattr(nn.Module, name):
                delattr(nn.Module, name)

        for name, original in [
            ("get_default_device", mock.original_get_default_device),
            ("set_default_device", mock.original_set_default_device),
        ]:
            if original is not None:
                setattr(torch, name, original)
            elif hasattr(torch, name):
                delattr(torch, name)

        global _TORCH_DEFAULT_DEVICE
        _TORCH_DEFAULT_DEVICE = torch.device("cpu")
